Master's thesis case study 3: Bandits with stopping

In [1]:
%load_ext autoreload
%autoreload 2
In [2]:
import numpy
import pandas as pd
import torch
from adaptive_nof1 import *
from adaptive_nof1.policies import *
from adaptive_nof1.helpers import *
from adaptive_nof1.inference import *
from adaptive_nof1.metrics import *
from matplotlib import pyplot as plt
import seaborn
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
In [3]:
# Generic n-of-1 parameters shared by every scenario and policy below.
number_of_actions = 2  # number of treatments each policy chooses between
number_of_patients = 100  # value of "n_patients" in the study design
block_length = 5  # block length handed to BlockPolicy / ExploreThenCommit
max_length = block_length * 10  # simulation horizon passed to simulate_configurations
In [4]:
# Scenarios
class NormalModel(Model):
    """Simulated patient whose outcome under treatment i is normal with mean[i], variance[i].

    Each patient draws outcomes from its own numpy Generator seeded with
    `patient_id`, so a patient's trajectory is reproducible.

    Args:
        patient_id: seed for the patient's random generator; also stored on the instance.
        mean: per-treatment expected outcomes (one entry per intervention).
        variance: per-treatment outcome variances (same length as `mean`).
    """

    def __init__(self, patient_id, mean, variance):
        self.rng = numpy.random.default_rng(patient_id)
        self.mean = mean
        self.variance = variance
        self.patient_id = patient_id

    def multivariate_normal_distribution(self, debug_data=None):
        """Return the true outcome distribution as a diagonal torch MultivariateNormal.

        `debug_data` is accepted for call compatibility and is not used.
        """
        # Explicit float64: the scenario parameters are Python ints, which would
        # otherwise yield integer tensors that MultivariateNormal cannot factorise.
        loc = torch.as_tensor(self.mean, dtype=torch.float64)
        cov = torch.diag_embed(torch.as_tensor(self.variance, dtype=torch.float64))
        return torch.distributions.MultivariateNormal(loc, cov)

    def generate_context(self, history):
        """This model has no context; always returns an empty dict."""
        return {}

    @property
    def additional_config(self):
        """Expose the true per-treatment means (read by the KL metric's true distribution)."""
        return {"expectations_of_interventions": self.mean}

    @property
    def number_of_interventions(self):
        """Number of available treatments, i.e. the length of `mean`."""
        return len(self.mean)

    def observe_outcome(self, action, context):
        """Draw one outcome for the treatment index stored in `action["treatment"]`."""
        treatment_index = action["treatment"]
        # numpy's `normal` expects the standard deviation (`scale`), not the variance.
        standard_deviation = numpy.sqrt(self.variance[treatment_index])
        return {"outcome": self.rng.normal(self.mean[treatment_index], standard_deviation)}

    def __str__(self):
        # NOTE: this exact string is used as a lookup key in `model_mapping`
        # when scoring; do not change its format.
        return f"NormalModel({self.mean, self.variance})"

def _scenario_factory(mean):
    """Return a `patient_id -> NormalModel` constructor with the given means and unit variances.

    A fresh list is built on every call so no two models share mutable parameters.
    """
    return lambda patient_id: NormalModel(patient_id, mean=list(mean), variance=[1, 1])

# Scenario I: no treatment effect; II: effect 1; III: effect 2 (difference of the means).
generating_scenario_I = _scenario_factory([0, 0])
generating_scenario_II = _scenario_factory([1, 0])
generating_scenario_III = _scenario_factory([2, 0])
In [5]:
# Inference Model
def inference_model():
    """Fresh inference model shared by the policies and the stopping rule."""
    return NormalKnownVariance(prior_mean=0, prior_variance=1, variance=1)


# Stopping Time
ALPHA_STOPPING = 0.01


def alpha_stopping_time(history, context, alpha=None):
    """Stopping rule: stop once the leading arm's probability exceeds 1 - alpha.

    A fresh model from `inference_model()` (the same model the policies use, so
    priors cannot drift apart) is updated on the full `history`. Its
    `approximate_max_probabilities` are presumably the posterior probabilities
    that each arm is the best one (TODO confirm against adaptive_nof1).

    Args:
        history: trial history so far, forwarded to `update_posterior`.
        context: current context, forwarded to `approximate_max_probabilities`.
        alpha: tolerance; defaults to the notebook-level ALPHA_STOPPING.

    Returns:
        True if 1 - max(probabilities) < alpha.
    """
    if alpha is None:
        alpha = ALPHA_STOPPING
    model = inference_model()
    model.update_posterior(history, number_of_actions)
    probabilities = model.approximate_max_probabilities(number_of_actions, context)
    return 1 - max(probabilities) < alpha
In [6]:
# Policies
# Policy hyper-parameters (values as in the original study setup).
ETC_EXPLORATION_LENGTH = 4
UCB_EPSILON = 0.05


def with_blocks_and_stopping(internal_policy):
    """Wrap `internal_policy` the way every policy in this study is wrapped.

    The internal policy goes inside a BlockPolicy using the notebook-level
    `block_length`, which in turn goes inside a StoppingPolicy that consults
    `alpha_stopping_time`.
    """
    return StoppingPolicy(
        policy=BlockPolicy(
            block_length=block_length,
            internal_policy=internal_policy,
        ),
        stopping_time=alpha_stopping_time,
    )


# All policies use the setup constant `number_of_actions` rather than a literal 2,
# so the number of arms is configured in exactly one place.
fixed_policy = with_blocks_and_stopping(
    FixedPolicy(
        number_of_actions=number_of_actions,
        inference_model=inference_model(),
    )
)

explore_then_commit = with_blocks_and_stopping(
    ExploreThenCommit(
        number_of_actions=number_of_actions,
        exploration_length=ETC_EXPLORATION_LENGTH,
        block_length=block_length,
        inference_model=inference_model(),
    )
)

thompson_sampling_policy = with_blocks_and_stopping(
    ThompsonSampling(
        inference_model=inference_model(),
        number_of_actions=number_of_actions,
    )
)

ucb_policy = with_blocks_and_stopping(
    UpperConfidenceBound(
        inference_model=inference_model(),
        number_of_actions=number_of_actions,
        epsilon=UCB_EPSILON,
    )
)
In [7]:
# Full crossover study: configurations are the cross product of the lists below.
study_policies = [fixed_policy, explore_then_commit, thompson_sampling_policy, ucb_policy]
study_scenarios = [generating_scenario_I, generating_scenario_II, generating_scenario_III]

study_designs = dict(
    n_patients=[number_of_patients],
    policy=study_policies,
    model_from_patient_id=study_scenarios,
)
configurations = generate_configuration_cross_product(study_designs)
In [8]:
# Run every configuration with simulation horizon `max_length`.
calculated_series, config_to_simulation_data = simulate_configurations(configurations, max_length)
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 1.95287569]
[2.10230224 3.28970725]
[1.68757613 3.28970725]
[1.73281824 3.28970725]
[1.89482357 3.28970725]
[3.28970725 1.78263585]
[2.11100749 3.28970725]
[1.69909177 3.28970725]
[3.28970725 0.7356268 ]
[3.28970725 1.84953347]
[3.28970725 1.48703369]
[3.28970725 2.2207267 ]
[2.60626885 3.28970725]
[2.10139188 3.28970725]
[1.06305465 3.28970725]
[3.28970725 1.5903371 ]
[2.57302619 3.28970725]
[3.28970725 1.54313908]
[1.92191388 3.28970725]
[3.28970725 2.10131065]
[3.28970725 2.41437005]
[2.20606519 3.28970725]
[3.28970725 1.40577887]
[1.72323411 3.28970725]
[1.91971212 3.28970725]
[1.51226677 3.28970725]
[1.76947968 3.28970725]
[3.28970725 2.38250763]
[1.66243489 3.28970725]
[1.81691371 3.28970725]
[2.06054612 3.28970725]
[1.96420362 3.28970725]
[2.19175514 3.28970725]
[3.28970725 1.73492909]
[2.3192803  3.28970725]
[3.28970725 2.41342734]
[1.78918069 3.28970725]
[2.04259023 3.28970725]
[2.4482715  3.28970725]
[3.28970725 1.6629994 ]
[3.28970725 1.37933528]
[3.28970725 1.61955171]
[3.28970725 1.75311485]
[1.49172375 3.28970725]
[3.28970725 2.55828373]
[1.87732424 3.28970725]
[3.28970725 1.91978472]
[3.28970725 0.94912561]
[2.38625532 3.28970725]
[2.05114842 3.28970725]
[3.28970725 2.07848305]
[3.28970725 1.88759929]
[3.28970725 1.8492947 ]
[1.60664441 3.28970725]
[2.13719975 3.28970725]
[3.28970725 1.54683802]
[3.28970725 2.3376584 ]
[3.28970725 2.55935095]
[3.28970725 1.12760385]
[2.17969147 3.28970725]
[1.39754359 3.28970725]
[3.28970725 1.34065998]
[1.20058969 3.28970725]
[3.28970725 2.33910739]
[1.89298032 3.28970725]
[1.60545794 3.28970725]
[3.28970725 1.43946271]
[1.98447902 3.28970725]
[3.28970725 2.71502871]
[2.1064876  3.28970725]
[3.28970725 2.04738365]
[2.06403838 3.28970725]
[1.33063265 3.28970725]
[3.28970725 1.90824339]
[3.28970725 1.93508305]
[3.28970725 2.0257046 ]
[2.54237517 3.28970725]
[2.17975653 3.28970725]
[1.74457456 3.28970725]
[2.09638938 3.28970725]
[1.51829405 3.28970725]
[3.28970725 2.31133553]
[2.30040541 3.28970725]
[3.28970725 1.49083153]
[1.17818228 3.28970725]
[1.62489326 3.28970725]
[1.69940162 3.28970725]
[3.28970725 2.20661127]
[1.5382256  3.28970725]
[2.17364767 3.28970725]
[2.15845385 3.28970725]
[3.28970725 1.85660762]
[0.82429786 3.28970725]
[3.28970725 2.36626936]
[3.28970725 2.13584086]
[2.12926729 3.28970725]
[3.28970725 1.49912721]
[1.62173627 3.28970725]
[1.80987343 3.28970725]
[1.68533444 3.28970725]
[2.02624912 1.95287569]
[2.10230224 2.11053662]
[1.73281824 1.91737436]
[1.89482357 1.611254  ]
[2.11163049 1.78263585]
[2.11100749 2.64304136]
[1.69909177 1.80166771]
[1.74767096 1.84953347]
[1.73330232 2.2207267 ]
[2.10139188 2.30381042]
[2.23221129 1.5903371 ]
[1.59601251 1.54313908]
[1.92191388 2.84914979]
[1.86102076 2.10131065]
[1.72323411 2.09355623]
[1.91971212 1.74047112]
[1.76947968 2.16002627]
[1.76004928 2.38250763]
[1.66243489 1.86662601]
[1.81691371 1.90788863]
[2.06054612 2.57566623]
[1.96420362 2.17478056]
[2.19175514 1.4523873 ]
[2.3192803  1.88114517]
[1.78918069 1.88086876]
[2.4482715  2.36225855]
[1.64481113 1.37933528]
[1.47997932 1.61955171]
[1.52559143 1.75311485]
[1.87732424 2.09580938]
[1.70446898 1.91978472]
[1.40944019 0.94912561]
[2.38625532 2.05004173]
[2.05114842 1.23511877]
[2.17641755 2.07848305]
[1.64283545 1.88759929]
[1.99808681 1.8492947 ]
[2.13719975 2.17345956]
[2.43327782 2.3376584 ]
[1.20058969 1.25234479]
[2.9694922  2.33910739]
[1.89298032 1.84571926]
[1.60545794 1.6712577 ]
[2.05169945 1.43946271]
[1.98447902 2.22525195]
[2.1064876  2.09326514]
[1.73038114 2.04738365]
[2.06403838 1.67873992]
[2.00469347 1.93508305]
[2.17975653 1.54663394]
[2.09638938 1.72174411]
[2.41144931 2.31133553]
[1.62489326 1.89182652]
[1.69940162 1.90144359]
[2.05437533 2.20661127]
[2.17364767 2.02901327]
[2.15845385 2.52187589]
[1.58988509 1.85660762]
[1.67909647 2.13584086]
[2.12926729 2.25337757]
[1.62173627 1.40613098]
[1.80987343 2.25089758]
[1.68533444 2.00482809]
[1.45546191 1.95287569]
[2.10230224 1.82555816]
[1.73281824 1.59974416]
[1.88892068 1.78263585]
[1.69909177 1.7296998 ]
[1.74767096 2.00002304]
[2.10139188 2.11807736]
[1.53382306 1.54313908]
[1.86102076 2.02638065]
[1.72323411 1.92435447]
[2.01659982 1.74047112]
[1.66243489 1.93983713]
[1.81691371 1.32841053]
[2.06054612 2.30183742]
[1.78918069 1.83713528]
[2.20127104 2.36225855]
[1.57372887 1.37933528]
[1.47997932 1.65577113]
[1.87732424 1.66195466]
[1.70446898 1.63232217]
[1.98065583 2.05004173]
[1.50174701 2.07848305]
[1.98978005 1.8492947 ]
[2.13719975 1.79635951]
[1.84616565 2.3376584 ]
[1.20058969 1.32280047]
[1.44770257 1.84571926]
[1.98447902 1.7331086 ]
[2.02162664 2.09326514]
[1.73038114 1.98219429]
[1.97301153 1.67873992]
[2.08266903 1.93508305]
[2.03758313 1.72174411]
[2.14193374 2.31133553]
[1.62489326 1.79524291]
[1.69940162 1.68381888]
[2.05437533 2.00918149]
[1.8872172  2.02901327]
[2.15845385 2.39983754]
[1.67909647 1.67441897]
[2.12926729 1.87552549]
[1.5104976  1.40613098]
[1.80987343 2.00936299]
[1.45546191 1.80025815]
[1.83137204 1.82555816]
[1.74833863 1.59974416]
[1.55662779 1.78263585]
[1.69909177 1.4346049 ]
[1.74767096 2.01881252]
[2.10139188 1.89434984]
[1.53382306 1.74301719]
[1.86102076 1.86095275]
[1.72323411 1.74201709]
[1.63135659 1.32841053]
[1.78918069 1.62235714]
[1.54959548 1.37933528]
[1.76059101 1.66195466]
[1.78097328 1.63232217]
[1.98065583 2.25689838]
[1.84181141 1.8492947 ]
[1.81677093 1.79635951]
[1.91809598 1.7331086 ]
[2.02162664 2.11153002]
[1.84854473 1.67873992]
[1.90262136 1.93508305]
[2.14193374 2.15972344]
[1.62489326 1.41831039]
[1.72172626 2.00918149]
[2.15845385 2.0209795 ]
[1.41075384 1.67441897]
[1.92568072 1.87552549]
[1.58256047 1.40613098]
[1.80987343 1.91561085]
[1.73072034 1.82555816]
[1.55662779 1.68338809]
[1.40932378 1.4346049 ]
[1.70355914 1.86095275]
[1.72323411 1.59888478]
[1.52677934 1.62235714]
[1.58164026 1.66195466]
[1.75941664 1.63232217]
[1.98065583 1.99357756]
[1.84181141 1.83241122]
[1.81810498 1.79635951]
[1.82636405 1.7331086 ]
[2.02162664 1.9906502 ]
[1.83525043 1.67873992]
[1.90262136 1.84632576]
[2.14193374 1.93735343]
[1.669262   1.41831039]
[1.72172626 1.62149241]
[1.92089956 2.0209795 ]
[1.82186837 1.87552549]
[1.80987343 1.97651927]
[1.73072034 1.64064982]
[1.55662779 1.53235852]
[1.40932378 1.333351  ]
[1.70355914 1.58722513]
[1.99370799 1.59888478]
[1.52677934 1.57759901]
[1.58164026 1.47487419]
[1.78926389 1.63232217]
[1.98065583 1.97807017]
[1.70095359 1.83241122]
[1.7484532 1.7331086]
[1.92583206 1.9906502 ]
[1.7634656  1.67873992]
[1.83170639 1.84632576]
[1.95902034 1.93735343]
[1.67747836 1.62149241]
[1.92089956 1.98973009]
[1.82186837 1.66252411]
[1.45102069 1.53235852]
[1.31203018 1.333351  ]
[1.57912507 1.58722513]
[1.36728217 1.47487419]
[1.91416668 1.97807017]
[1.63330179 1.7331086 ]
[1.92583206 1.96123597]
[1.7621056  1.67873992]
[1.83170639 1.87612481]
[1.92625841 1.93735343]
[1.64559877 1.62149241]
[1.92089956 1.76803282]
[1.71593963 1.66252411]
[1.45102069 1.38165296]
[1.31203018 1.36821809]
[1.57912507 1.47486883]
[1.91416668 2.0787459 ]
[1.92583206 1.73835468]
[1.81095685 1.67873992]
[1.83170639 1.92981191]
[1.92625841 1.94926123]
[1.68400988 1.62149241]
[1.88739486 1.76803282]
[1.69474064 1.66252411]
[1.31203018 1.42522878]
[1.63272255 1.47486883]
[1.74353077 1.67873992]
[1.83170639 1.8180798 ]
[1.92625841 1.86139177]
[1.65777542 1.62149241]
[1.64666145 1.76803282]
[1.75434398 1.66252411]
  0%|          | 0/50 [00:00<?, ?it/s]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 1.95287569]
[3.28970725 2.10230224]
[3.28970725 1.68757613]
[3.28970725 1.73281824]
[2.7281569  3.28970725]
[3.28970725 1.78263585]
[2.94434083 3.28970725]
[3.28970725 1.69909177]
[3.28970725 0.7356268 ]
[3.28970725 1.84953347]
[2.32036703 3.28970725]
[3.28970725 2.2207267 ]
[3.43960218 3.28970725]
[3.28970725 2.10139188]
[1.89638799 3.28970725]
[3.28970725 1.5903371 ]
[3.40635952 3.28970725]
[3.28970725 1.54313908]
[2.75524721 3.28970725]
[3.28970725 2.10131065]
[3.24770338 3.28970725]
[3.28970725 2.20606519]
[3.28970725 1.40577887]
[2.55656745 3.28970725]
[2.75304545 3.28970725]
[3.28970725 1.51226677]
[2.60281301 3.28970725]
[3.21584096 3.28970725]
[3.28970725 1.66243489]
[3.28970725 1.81691371]
[2.89387946 3.28970725]
[3.28970725 1.96420362]
[3.02508847 3.28970725]
[2.56826242 3.28970725]
[3.28970725 2.3192803 ]
[3.24676068 3.28970725]
[2.62251402 3.28970725]
[3.28970725 2.04259023]
[3.28160483 3.28970725]
[2.49633273 3.28970725]
[3.28970725 1.37933528]
[2.45288504 3.28970725]
[2.58644818 3.28970725]
[2.32505708 3.28970725]
[3.28970725 2.55828373]
[2.71065757 3.28970725]
[2.75311805 3.28970725]
[3.28970725 0.94912561]
[3.21958866 3.28970725]
[3.28970725 2.05114842]
[3.28970725 2.07848305]
[3.28970725 1.88759929]
[3.28970725 1.8492947 ]
[3.28970725 1.60664441]
[2.97053309 3.28970725]
[2.38017135 3.28970725]
[3.17099173 3.28970725]
[3.39268428 3.28970725]
[1.96093718 3.28970725]
[3.01302481 3.28970725]
[3.28970725 1.39754359]
[3.28970725 1.34065998]
[2.03392302 3.28970725]
[3.17244072 3.28970725]
[2.72631365 3.28970725]
[3.28970725 1.60545794]
[2.27279604 3.28970725]
[3.28970725 1.98447902]
[3.54836204 3.28970725]
[3.28970725 2.1064876 ]
[2.88071698 3.28970725]
[2.89737172 3.28970725]
[2.16396599 3.28970725]
[2.74157672 3.28970725]
[2.76841638 3.28970725]
[3.28970725 2.0257046 ]
[3.28970725 2.54237517]
[3.01308987 3.28970725]
[3.28970725 1.74457456]
[3.28970725 2.09638938]
[2.35162738 3.28970725]
[3.28970725 2.31133553]
[3.28970725 2.30040541]
[3.28970725 1.49083153]
[2.01151561 3.28970725]
[3.28970725 1.62489326]
[3.28970725 1.69940162]
[3.28970725 2.20661127]
[2.37155893 3.28970725]
[3.006981   3.28970725]
[2.99178718 3.28970725]
[3.28970725 1.85660762]
[1.65763119 3.28970725]
[3.19960269 3.28970725]
[2.96917419 3.28970725]
[2.96260062 3.28970725]
[3.28970725 1.49912721]
[2.45506961 3.28970725]
[3.28970725 1.80987343]
[2.51866778 3.28970725]
[2.75070769 1.73281824]
[2.94434083 2.64304136]
[2.32036703 2.2276421 ]
[2.87270141 3.28970725]
[1.89638799 2.24126513]
[2.84786216 3.28970725]
[2.75524721 2.84914979]
[2.69435409 2.10131065]
[2.4064104  2.20606519]
[2.89387946 2.57566623]
[2.56826242 2.52380305]
[2.71447851 2.3192803 ]
[2.50883207 2.04259023]
[2.49633273 2.33200857]
[2.32505708 2.19181503]
[2.59318286 2.55828373]
[2.06845211 2.05114842]
[2.38017135 2.43472707]
[3.04272342 3.28970725]
[1.96093718 1.8996458 ]
[3.17244072 2.9694922 ]
[2.27279604 2.05169945]
[3.18584349 3.28970725]
[2.16396599 2.24016872]
[2.80707162 2.54237517]
[2.55507744 2.09638938]
[2.35162738 2.57830882]
[2.44816857 2.30040541]
[2.01151561 2.06731494]
[2.37155893 2.21960381]
[2.99178718 2.52187589]
[1.65763119 1.89560817]
[2.96260062 2.25337757]
[2.63163201 2.64304136]
[2.20083775 2.2276421 ]
[2.75524721 2.43557852]
[2.470283   2.20606519]
[2.92995372 2.57566623]
[2.42745839 2.52380305]
[2.48647352 2.33200857]
[2.3895279  2.19181503]
[2.73213327 2.55828373]
[2.38017135 2.25598459]
[2.07922724 1.8996458 ]
[2.88058391 2.9694922 ]
[2.16396599 1.91148607]
[2.76792949 2.54237517]
[2.35162738 2.04195274]
[2.51355842 2.30040541]
[2.01151561 1.70160283]
[2.45618929 2.21960381]
[1.65763119 1.78583052]
[2.63163201 2.70935241]
[2.42745839 2.04952202]
[2.75726895 2.55828373]
[2.45988615 2.25598459]
[2.23967778 1.91148607]
[2.25202635 2.04195274]
[2.42429494 2.30040541]
[2.20367775 1.70160283]
[1.65763119 1.65821152]
[2.37613044 2.04195274]
[2.38163517 2.30040541]
[1.65763119 1.68156906]
[2.46982378 2.30040541]
[1.65763119 1.76957808]
[1.65763119 1.76896305]
[1.65763119 1.72210251]
[1.65763119 1.70304922]
  0%|          | 0/50 [00:00<?, ?it/s]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.28970725 3.28970725]
[3.61954236 3.28970725]
[3.28970725 2.10230224]
[3.28970725 1.68757613]
[3.28970725 1.73281824]
[3.28970725 1.89482357]
[3.44930252 3.28970725]
[3.77767416 3.28970725]
[3.36575843 3.28970725]
[2.40229346 3.28970725]
[3.28970725 1.84953347]
[3.15370036 3.28970725]
[3.28970725 2.2207267 ]
[4.27293551 3.28970725]
[3.28970725 2.10139188]
[2.72972132 3.28970725]
[3.25700376 3.28970725]
[3.28970725 2.57302619]
[3.28970725 1.54313908]
[3.28970725 1.92191388]
[3.76797731 3.28970725]
[3.28970725 2.41437005]
[3.28970725 2.20606519]
[3.28970725 1.40577887]
[3.38990078 3.28970725]
[3.28970725 1.91971212]
[3.28970725 1.51226677]
[3.28970725 1.76947968]
[4.04917429 3.28970725]
[3.32910156 3.28970725]
[3.28970725 1.81691371]
[3.72721279 3.28970725]
[3.28970725 1.96420362]
[3.85842181 3.28970725]
[3.28970725 1.73492909]
[3.28970725 2.3192803 ]
[4.08009401 3.28970725]
[3.45584735 3.28970725]
[3.7092569  3.28970725]
[3.28970725 2.4482715 ]
[3.32966607 3.28970725]
[3.28970725 1.37933528]
[3.28621837 3.28970725]
[3.28970725 1.75311485]
[3.15839042 3.28970725]
[4.22495039 3.28970725]
[3.28970725 1.87732424]
[3.58645139 3.28970725]
[2.61579228 3.28970725]
[4.05292199 3.28970725]
[3.71781508 3.28970725]
[3.74514972 3.28970725]
[3.55426596 3.28970725]
[3.28970725 1.8492947 ]
[3.27331107 3.28970725]
[3.80386642 3.28970725]
[3.28970725 1.54683802]
[3.28970725 2.3376584 ]
[3.28970725 2.55935095]
[3.28970725 1.12760385]
[3.28970725 2.17969147]
[3.28970725 1.39754359]
[3.00732665 3.28970725]
[2.86725635 3.28970725]
[3.28970725 2.33910739]
[3.55964698 3.28970725]
[3.27212461 3.28970725]
[3.10612937 3.28970725]
[3.65114568 3.28970725]
[4.38169537 3.28970725]
[3.28970725 2.1064876 ]
[3.28970725 2.04738365]
[3.73070505 3.28970725]
[2.99729932 3.28970725]
[3.57491006 3.28970725]
[3.60174972 3.28970725]
[3.69237126 3.28970725]
[4.20904183 3.28970725]
[3.8464232  3.28970725]
[3.28970725 1.74457456]
[3.76305605 3.28970725]
[3.18496072 3.28970725]
[3.9780022  3.28970725]
[3.96707207 3.28970725]
[3.15749819 3.28970725]
[3.28970725 1.17818228]
[3.28970725 1.62489326]
[3.36606829 3.28970725]
[3.87327794 3.28970725]
[3.20489226 3.28970725]
[3.84031433 3.28970725]
[3.28970725 2.15845385]
[3.52327428 3.28970725]
[2.49096453 3.28970725]
[4.03293602 3.28970725]
[3.28970725 2.13584086]
[3.28970725 2.12926729]
[3.28970725 1.49912721]
[3.28970725 1.62173627]
[3.28970725 1.80987343]
[3.35200111 3.28970725]
[3.68954924 3.28970725]
[3.6432628  3.28970725]
[4.11223508 3.28970725]
[3.42862269 3.28970725]
[2.72972132 2.24126513]
[3.6803892  3.28970725]
[3.19011963 2.41437005]
[3.60100316 3.28970725]
[3.44405983 3.28970725]
[4.04796063 3.28970725]
[3.50683158 3.28970725]
[3.82959375 3.28970725]
[3.52096268 3.28970725]
[3.54716605 3.28970725]
[3.69821278 3.28970725]
[3.49598318 3.28970725]
[3.93891592 3.28970725]
[3.3116269  3.28970725]
[3.8399724  3.28970725]
[3.4448092  3.28970725]
[3.27331107 2.73014367]
[3.87038624 3.28970725]
[3.5584082  3.28970725]
[3.81533442 3.28970725]
[3.56063296 3.28970725]
[3.37582657 3.28970725]
[3.66808653 3.28970725]
[3.9645547  3.28970725]
[3.55169415 3.28970725]
[3.60173579 3.28970725]
[3.18496072 2.57830882]
[3.65470333 3.28970725]
[3.48321491 3.28970725]
[3.84329203 3.28970725]
[3.81147803 3.28970725]
[3.39902263 3.28970725]
[2.49096453 1.89560817]
[3.87329257 3.28970725]
[3.53193345 3.28970725]
[3.40235174 3.28970725]
[3.63651471 3.28970725]
[3.49572132 3.28970725]
[3.76041277 3.28970725]
[3.63859979 3.28970725]
[3.62644441 3.28970725]
[3.60768729 3.28970725]
[3.76543712 3.28970725]
[3.60336656 3.28970725]
[3.53226252 3.28970725]
[3.62834705 3.28970725]
[3.43079065 3.28970725]
[3.68075754 3.28970725]
[3.48127551 3.28970725]
[3.62949303 3.28970725]
[3.35683354 3.28970725]
[3.60508509 3.28970725]
[3.65536619 3.28970725]
[3.35285869 3.28970725]
[3.8268844  3.28970725]
[3.8322275  3.28970725]
[3.62759041 3.28970725]
[3.71588574 3.28970725]
[3.63511674 3.28970725]
[3.46429438 3.28970725]
[3.82109632 3.28970725]
[3.7277351  3.28970725]
[3.60416946 3.28970725]
[3.92976818 3.28970725]
[3.779973   3.28970725]
[3.45348635 3.28970725]
[3.44343584 3.28970725]
[3.32659627 3.28970725]
[3.69769815 3.28970725]
[3.5677129  3.28970725]
[3.69031652 3.28970725]
[3.54772823 3.28970725]
[3.69820156 3.28970725]
[3.49538529 3.28970725]
[3.54007351 3.28970725]
[3.73237045 3.28970725]
[3.53602584 3.28970725]
[3.68440632 3.28970725]
[3.38243491 3.28970725]
[3.62701676 3.28970725]
[3.50125698 3.28970725]
[3.66064491 3.28970725]
[3.63616419 3.28970725]
[3.41070493 3.28970725]
[3.75060632 3.28970725]
[3.65900988 3.28970725]
[3.62076431 3.28970725]
[3.59008212 3.28970725]
[3.53816401 3.28970725]
[3.7684615  3.28970725]
[3.70239419 3.28970725]
[3.60848906 3.28970725]
[3.78927997 3.28970725]
[3.7449954  3.28970725]
[3.56058446 3.28970725]
[3.46667717 3.28970725]
[3.27563587 3.28970725]
[3.66248797 3.28970725]
[3.48562607 3.28970725]
[3.65344189 3.28970725]
[3.56393786 3.28970725]
[3.56075463 3.28970725]
[3.44086845 3.28970725]
[3.57416062 3.28970725]
[3.63569958 3.28970725]
[3.56986549 3.28970725]
[3.79903671 3.28970725]
[3.21643759 3.28970725]
[3.52507872 3.28970725]
[3.51364014 3.28970725]
[3.65370431 3.28970725]
[3.66626889 3.28970725]
[3.46872736 3.28970725]
[3.74853145 3.28970725]
[3.71109201 3.28970725]
[3.40397009 3.28970725]
[3.52316917 3.28970725]
[3.48181014 3.28970725]
[3.65535185 3.28970725]
[3.52281385 3.28970725]
[3.63121402 3.28970725]
[3.78902516 3.28970725]
[3.51582971 3.28970725]
[3.4236798  3.28970725]
[3.50907003 3.28970725]
[3.61992455 3.28970725]
[3.72811233 3.28970725]
[3.62575015 3.28970725]
[3.58565335 3.28970725]
[3.4398208  3.28970725]
[3.57069886 3.28970725]
[3.59009141 3.28970725]
[3.62065736 3.28970725]
[3.9843109  3.28970725]
[3.66424871 3.28970725]
[3.57887413 3.28970725]
[3.62877434 3.28970725]
[3.63331735 3.28970725]
[3.51627643 3.28970725]
[3.72534549 3.28970725]
[3.69196438 3.28970725]
[3.45556246 3.28970725]
[3.5990571  3.28970725]
[3.53961753 3.28970725]
[3.64753884 3.28970725]
[3.53205887 3.28970725]
[3.61413761 3.28970725]
[3.73555834 3.28970725]
[3.51645764 3.28970725]
[3.38054065 3.28970725]
[3.48077838 3.28970725]
[3.64564882 3.28970725]
[3.66218852 3.28970725]
[3.73363647 3.28970725]
[3.6228635  3.28970725]
[3.54646247 3.28970725]
[3.54835296 3.28970725]
[3.67278602 3.28970725]
[3.59552661 3.28970725]
[3.59628957 3.28970725]
[3.561294   3.28970725]
[3.56223149 3.28970725]
[3.65022238 3.28970725]
[3.49095276 3.28970725]
[3.75539233 3.28970725]
[3.69478159 3.28970725]
[3.53063714 3.28970725]
[3.66347642 3.28970725]
[3.55641383 3.28970725]
[3.67075326 3.28970725]
[3.5336596  3.28970725]
[3.47521144 3.28970725]
[3.80589616 3.28970725]
[3.57722147 3.28970725]
[3.32186438 3.28970725]
[3.42250935 3.28970725]
[3.56810322 3.28970725]
[3.66961415 3.28970725]
[3.74515388 3.28970725]
[3.72306898 3.28970725]
[3.54170229 3.28970725]
[3.55294894 3.28970725]
[3.76134607 3.28970725]
[3.64933276 3.28970725]
[3.58736738 3.28970725]
[3.71215351 3.28970725]
[3.49298415 3.28970725]
[3.70676046 3.28970725]
[3.56411594 3.28970725]
[3.79761408 3.28970725]
[3.62537905 3.28970725]
[3.52196116 3.28970725]
[3.67184792 3.28970725]
[3.57219259 3.28970725]
[3.66301659 3.28970725]
[3.57166898 3.28970725]
[3.44974749 3.28970725]
[3.82383064 3.28970725]
[3.64210542 3.28970725]
[3.3365258  3.28970725]
[3.46400122 3.28970725]
[3.62894845 3.28970725]
[3.63267062 3.28970725]
[3.69809931 3.28970725]
[3.76253778 3.28970725]
[3.57412259 3.28970725]
[3.53180688 3.28970725]
[3.70934412 3.28970725]
[3.72404677 3.28970725]
[3.58130391 3.28970725]
[3.71727782 3.28970725]
[3.46914003 3.28970725]
[3.657989   3.28970725]
[3.66790936 3.28970725]
[3.74883047 3.28970725]
[3.59957392 3.28970725]
[3.49797761 3.28970725]
[3.65785301 3.28970725]
[3.63599403 3.28970725]
[3.79911661 3.28970725]
[3.5662002  3.28970725]
[3.51004322 3.28970725]
[3.84401051 3.28970725]
In [10]:
# Todo: make the output table in a way that we chose the maximum index
def debug_data_to_torch_distribution(debug_data, observation_variance=1.0):
    """Build the posterior predictive outcome distribution from a policy's debug data.

    Args:
        debug_data: mapping with "mean" and "variance" entries, one value per
            action (assumed layout; TODO confirm against the inference model's
            debug output).
        observation_variance: outcome noise variance added to every posterior
            variance; defaults to 1, the variance used in all scenarios here.

    Returns:
        A diagonal float64 torch MultivariateNormal.
    """
    # Cast explicitly so integer-valued inputs still yield float tensors and the
    # mean and covariance always share one dtype.
    mean = torch.as_tensor(debug_data["mean"], dtype=torch.float64)
    variance = torch.as_tensor(debug_data["variance"], dtype=torch.float64) + observation_variance
    return torch.distributions.MultivariateNormal(mean, torch.diag_embed(variance))

def data_to_true_distribution(data, outcome_variance=1.0):
    """Return the true outcome distribution of a simulation's model.

    Args:
        data: object whose ``additional_config["expectations_of_interventions"]``
            holds the true per-arm means.
        outcome_variance: variance of every arm. It defaults to 1, the variance
            used by all scenarios in this notebook.

    Returns:
        Diagonal ``torch.distributions.MultivariateNormal`` in float64.
    """
    # The scenario means are ints (e.g. [0, 0]). Without an explicit dtype,
    # torch.tensor would produce an int64 loc and torch.eye a float32 covariance.
    mean = torch.as_tensor(data.additional_config["expectations_of_interventions"], dtype=torch.float64)
    cov = outcome_variance * torch.eye(len(mean), dtype=torch.float64)
    return torch.distributions.MultivariateNormal(mean, cov)


import pandas as pd

# Metrics for the scenario table. The KL divergence compares the true outcome
# distribution with the distribution reconstructed from each policy's debug data.
metrics = [
    SimpleRegretWithMean(),
    BestArmIdentification(),
    CumulativeRegret(),
    Length(),
    KLDivergence(
        data_to_true_distribution=data_to_true_distribution,
        debug_data_to_posterior_distribution=debug_data_to_torch_distribution,
    ),
]
# Short labels for the str() of each model and policy, as used in the thesis tables.
model_mapping = {
    "NormalModel(([0, 0], [1, 1]))": "I",
    "NormalModel(([1, 0], [1, 1]))": "II",
    "NormalModel(([2, 0], [1, 1]))": "III",
}
policy_mapping = {
    "StoppingPolicy(BlockPolicy(FixedPolicy))": "Fixed",
    "StoppingPolicy(BlockPolicy(ThompsonSampling(NormalKnownVariance(0, 1, 1))))": "TS",
    "StoppingPolicy(BlockPolicy(UpperConfidenceBound(0.05 epsilon, NormalKnownVariance(0, 1, 1))))": "UCB",
    "StoppingPolicy(BlockPolicy(ExploreThenCommit(4,NormalKnownVariance(0, 1, 1))))": "ETC",
}

df = SeriesOfSimulationsData.score_data(
    [s["result"] for s in calculated_series],
    metrics,
    {"model": lambda x: model_mapping[x], "policy": lambda x: policy_mapping[x]},
)
df = df.reset_index(drop=True)

# Keep only the row with the largest t for every (policy, metric, model, patient)
# combination, i.e. the score at the end of each run.
max_t_indices = df.groupby(["policy", "metric", "model", "patient_id"])["t"].idxmax()
# idxmax returns index *labels*, so select with .loc rather than positional .iloc.
filtered_df = df.loc[max_t_indices].reset_index(drop=True)
groupby_columns = ["model", "policy"]

# One row per patient and one column per metric, then mean/std over patients.
pivoted_df = filtered_df.pivot(
    index=["model", "policy", "patient_id"],
    columns="metric",
    values="score",
)
table = pivoted_df.groupby(groupby_columns).agg(["mean", "std"])

policy_ordering = ["Fixed", "ETC", "SH", "UCB", "TS"]

# Convert 'policy' to an ordered categorical so sorting follows policy_ordering.
table = table.reset_index()
table["policy"] = pd.Categorical(table["policy"], categories=policy_ordering, ordered=True)

# Sort by model, then by the ordered policy; keep and relabel the reported metrics.
# Raw strings keep the LaTeX macros free of invalid-escape warnings.
sorted_table = (
    table.sort_values(by=["model", "policy"])
    .set_index(groupby_columns)[
        [
            "Cumulative Regret (outcome)",
            "KL Divergence",
            "Simple Regret With Mean",
            "Length",
            "Best Arm Identification With Mean",
        ]
    ]
    .rename(
        columns={
            "Cumulative Regret (outcome)": "Regret",
            "Simple Regret With Mean": r"$\SIR$",
            "Best Arm Identification With Mean": r"$\BAI$",
            "KL Divergence": r"$\KLD$",
        },
    )
)
sorted_table.index.names = ["S.", "Policy"]
sorted_table
Out[10]:
metric Regret $\KLD$ $\SIR$ Length $\BAI$
mean std mean std mean std mean std mean std
S. Policy
I Fixed -0.037432 3.597176 0.163983 0.165737 0.0 0.000000 16.52 8.818575 0.48 0.502117
ETC -0.292172 4.041548 0.166418 0.164108 0.0 0.000000 17.17 11.315543 0.48 0.502117
UCB 0.309196 4.454360 0.17932 0.170011 0.0 0.000000 18.45 12.693254 0.46 0.500908
TS -0.147157 4.328133 0.150667 0.159053 0.0 0.000000 22.42 14.843412 0.52 0.502117
II Fixed -6.732649 4.478590 0.219876 0.134231 0.04 0.196946 11.08 4.856902 0.96 0.196946
ETC -6.911719 4.506643 0.218885 0.135457 0.05 0.219043 11.27 4.921310 0.95 0.219043
UCB -6.502219 5.637445 0.403294 0.284822 0.14 0.348735 11.6 6.539700 0.86 0.348735
TS -12.266483 14.320192 0.411553 0.281109 0.22 0.416333 18.79 12.119185 0.78 0.416333
III Fixed -10.109516 2.809726 0.376703 0.399681 0.04 0.281411 7.43 1.289076 0.98 0.140705
ETC -10.243181 2.724291 0.354477 0.347126 0.02 0.200000 7.55 1.366075 0.99 0.100000
UCB -33.577203 40.023253 1.111306 0.969029 0.38 0.788554 19.59 18.100781 0.81 0.394277
TS -41.302591 42.305474 1.16179 1.023429 0.5 0.870388 25.07 18.014167 0.76 0.429235
In [11]:
from pathlib import Path

# Export the summary table to LaTeX for the thesis. A dedicated variable name is
# used instead of `str`, which would shadow the builtin for every later cell.
latex_path = Path("mt_resources/7-stopping/01-table.tex")
latex_path.parent.mkdir(parents=True, exist_ok=True)
latex_table = sorted_table.style.format(precision=2).to_latex(hrules=True)
print(latex_table)
latex_path.write_text(latex_table, encoding="utf-8")
\begin{tabular}{lllrlrlrlrlr}
\toprule
 & metric & \multicolumn{2}{r}{Regret} & \multicolumn{2}{r}{$\KLD$} & \multicolumn{2}{r}{$\SIR$} & \multicolumn{2}{r}{Length} & \multicolumn{2}{r}{$\BAI$} \\
 &  & mean & std & mean & std & mean & std & mean & std & mean & std \\
S. & Policy &  &  &  &  &  &  &  &  &  &  \\
\midrule
\multirow[c]{4}{*}{I} & Fixed & -0.04 & 3.60 & 0.16 & 0.17 & 0.00 & 0.00 & 16.52 & 8.82 & 0.48 & 0.50 \\
 & ETC & -0.29 & 4.04 & 0.17 & 0.16 & 0.00 & 0.00 & 17.17 & 11.32 & 0.48 & 0.50 \\
 & UCB & 0.31 & 4.45 & 0.18 & 0.17 & 0.00 & 0.00 & 18.45 & 12.69 & 0.46 & 0.50 \\
 & TS & -0.15 & 4.33 & 0.15 & 0.16 & 0.00 & 0.00 & 22.42 & 14.84 & 0.52 & 0.50 \\
\multirow[c]{4}{*}{II} & Fixed & -6.73 & 4.48 & 0.22 & 0.13 & 0.04 & 0.20 & 11.08 & 4.86 & 0.96 & 0.20 \\
 & ETC & -6.91 & 4.51 & 0.22 & 0.14 & 0.05 & 0.22 & 11.27 & 4.92 & 0.95 & 0.22 \\
 & UCB & -6.50 & 5.64 & 0.40 & 0.28 & 0.14 & 0.35 & 11.60 & 6.54 & 0.86 & 0.35 \\
 & TS & -12.27 & 14.32 & 0.41 & 0.28 & 0.22 & 0.42 & 18.79 & 12.12 & 0.78 & 0.42 \\
\multirow[c]{4}{*}{III} & Fixed & -10.11 & 2.81 & 0.38 & 0.40 & 0.04 & 0.28 & 7.43 & 1.29 & 0.98 & 0.14 \\
 & ETC & -10.24 & 2.72 & 0.35 & 0.35 & 0.02 & 0.20 & 7.55 & 1.37 & 0.99 & 0.10 \\
 & UCB & -33.58 & 40.02 & 1.11 & 0.97 & 0.38 & 0.79 & 19.59 & 18.10 & 0.81 & 0.39 \\
 & TS & -41.30 & 42.31 & 1.16 & 1.02 & 0.50 & 0.87 & 25.07 & 18.01 & 0.76 & 0.43 \\
\bottomrule
\end{tabular}

In [12]:
def rename_df(df):
    """Add a short-policy-label column to ``df`` (in place) and return the same frame.

    The column name ``policy_#_metric_#_model_p`` is presumably what
    ``SeriesOfSimulationsData.plot_lines`` reads for its legend. TODO: confirm.
    """
    df["policy_#_metric_#_model_p"] = df["policy"].apply(policy_mapping.__getitem__)
    return df

# Cumulative regret over time for every policy in scenario II (arm means 1 and 0).
SeriesOfSimulationsData.plot_lines(
    [
        series["result"]
        for series in calculated_series
        if series["configuration"]["model"] == "NormalModel(([1, 0], [1, 1]))"
    ],
    [CumulativeRegret()],
    legend_position=(0.02, 0.3),
    process_df=rename_df,
)
plt.ylabel("Regret")
plt.savefig("mt_resources/7-stopping/01_cumulative_regret.pdf", bbox_inches="tight")
No description has been provided for this image
In [13]:
# Simple regret over time for every policy in scenario II (arm means 1 and 0).
SeriesOfSimulationsData.plot_lines(
    [
        series["result"]
        for series in calculated_series
        if series["configuration"]["model"] == "NormalModel(([1, 0], [1, 1]))"
    ],
    [SimpleRegretWithMean()],
    legend_position=(0.8, 1.0),
    process_df=rename_df,
)
plt.ylabel("Simple Regret")
plt.savefig("mt_resources/7-stopping/01_simple_regret.pdf", bbox_inches="tight")
No description has been provided for this image
In [14]:
# KL divergence between the true and the reconstructed outcome distribution over
# time, for every policy in scenario II (arm means 1 and 0).
SeriesOfSimulationsData.plot_lines(
    [
        series["result"]
        for series in calculated_series
        if series["configuration"]["model"] == "NormalModel(([1, 0], [1, 1]))"
    ],
    [
        KLDivergence(
            data_to_true_distribution=data_to_true_distribution,
            debug_data_to_posterior_distribution=debug_data_to_torch_distribution,
        )
    ],
    legend_position=(0.8, 1.0),
    process_df=rename_df,
)
plt.ylabel("KL Divergence")
plt.savefig("mt_resources/7-stopping/01-kl-divergence.pdf", bbox_inches="tight")
No description has been provided for this image
In [15]:
# Number of patients whose run has stopped at each time step, per policy, in
# scenario II (arm means 1 and 0).
is_stopped_df = SeriesOfSimulationsData.score_data(
    [s["result"] for s in calculated_series if s["configuration"]["model"] == "NormalModel(([1, 0], [1, 1]))"],
    [IsStopped()],
)
is_stopped_df["policy"] = is_stopped_df["policy"].apply(lambda x: policy_mapping[x])

# Sum only the score column. Summing the whole frame would also "add up"
# non-numeric or meaningless columns such as metric names and patient ids.
stopped_counts = (
    is_stopped_df.groupby(["policy", "model", "t"])["score"]
    .sum()
    .reset_index()
)

ax = seaborn.lineplot(data=stopped_counts, x="t", y="score", hue="policy")
ax.set_ylabel("Number of patients")
seaborn.move_legend(ax, "upper right", title=None)
plt.savefig("mt_resources/7-stopping/01_is_stopped.pdf", bbox_inches="tight")
No description has been provided for this image
In [16]:
# Allocation overview for every entry of calculated_series (all scenarios x policies);
# presumably shows which treatment each patient received over time. TODO: confirm.
plot_allocations_for_calculated_series(calculated_series)
/opt/homebrew/Caskroom/miniconda/base/envs/mt/lib/python3.11/site-packages/holoviews/plotting/bokeh/plot.py:987: UserWarning: found multiple competing values for 'toolbar.active_drag' property; using the latest value
  layout_plot = gridplot(
/opt/homebrew/Caskroom/miniconda/base/envs/mt/lib/python3.11/site-packages/holoviews/plotting/bokeh/plot.py:987: UserWarning: found multiple competing values for 'toolbar.active_scroll' property; using the latest value
  layout_plot = gridplot(
Out[16]:
In [17]:
# Allocations of the UCB policy in scenario I (both arms have mean 0).
plot_allocations_for_calculated_series(
    [
        series
        for series in calculated_series
        if series["configuration"]["policy"]
        == "StoppingPolicy(BlockPolicy(UpperConfidenceBound(0.05 epsilon, NormalKnownVariance(0, 1, 1))))"
        and series["configuration"]["model"] == "NormalModel(([0, 0], [1, 1]))"
    ]
)
Out[17]:
:Layout
In [18]:
# Allocations of the UCB policy in scenario II (arm means 1 and 0).
plot_allocations_for_calculated_series(
    [
        series
        for series in calculated_series
        if series["configuration"]["policy"]
        == "StoppingPolicy(BlockPolicy(UpperConfidenceBound(0.05 epsilon, NormalKnownVariance(0, 1, 1))))"
        and series["configuration"]["model"] == "NormalModel(([1, 0], [1, 1]))"
    ]
)
Out[18]:
:Layout
In [21]:
import param
import pandas
import panel
from functools import reduce


class PatientExplorer(param.Parameterized):
    """Interactive browser for the debug data of one simulated patient.

    ``configuration`` indexes into ``calculated_series``. ``patient_id`` selects one
    simulation of that configuration's result.
    """

    patient_id = param.Integer(default=0, bounds=(0, number_of_patients - 1))
    configuration = param.Integer(default=0, bounds=(0, len(calculated_series) - 1))
    # NOTE(review): `t` is exposed as a widget but not read by any view below.
    t = param.Integer(default=10, bounds=(0, max_length - 1))

    @param.depends("patient_id", "configuration")
    def hvplot(self):
        """Plot the selected patient's flattened debug data.

        Each row of the plotted frame is one entry of ``history.debug_data()``.
        """
        debug_data = (
            calculated_series[self.configuration]["result"]
            .simulations[self.patient_id]
            .history.debug_data()
        )
        df = pandas.DataFrame([flatten_dictionary(d) for d in debug_data])
        # `.hvplot` needs the hvplot pandas accessor to be registered
        # (e.g. `import hvplot.pandas`); presumably done by an earlier import. TODO: confirm.
        return df.hvplot()

    @param.depends("configuration")
    def configuration_name(self):
        """Render the configuration dict of the selected series as a panel pane."""
        return panel.panel(calculated_series[self.configuration]["configuration"])
In [22]:
# Dashboard: parameter widgets plus the selected configuration on the left; on the
# right, a DynamicMap wrapping the param-dependent `explorer.hvplot` view.
import holoviews
import panel

explorer = PatientExplorer()
hvplot = holoviews.DynamicMap(explorer.hvplot)

panel.Column(
    panel.Row(
        panel.Column(
            explorer.param,
            explorer.configuration_name,
        ),
        hvplot,
    ),
)
Out[22]:
In [ ]: